/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

 #include <iostream>
 #include <fstream>
 #include <vector>
 #include <gtest/gtest.h>
 #include "log/log.h"
 #include "ut_op_common.h"
 #include "register/op_impl_registry.h"
 #include "platform/platform_infos_def.h"
 #include "ut_op_util.h"
 #include "kernel_run_context_facker.h"
 #include "test_cube_util.h"
 #include "exe_graph/runtime/storage_format.h"
 #include "exe_graph/runtime/storage_shape.h"
 #include "../../../op_host/swi_glu_quant_tiling.h"
 #include "../../../op_graph/swi_glu_quant_proto.h"
 #include "tiling/platform/platform_ascendc.h"

 using namespace ut_util;
 using namespace std;
 using namespace ge;

 /**
  * Test fixture for the SwiGluQuant tiling tests.
  * Only logs when the suite starts and finishes; no per-test state is kept.
  */
 class SwiGluQuantTiling : public testing::Test {
  protected:
   // Runs once before the first test in this suite.
   static void SetUpTestCase()
   {
     std::cout << "SwiGluQuantTiling SetUp" << '\n' << std::flush;
   }

   // Runs once after the last test in this suite.
   static void TearDownTestCase()
   {
     std::cout << "SwiGluQuantTiling TearDown" << '\n' << std::flush;
   }
 };

 /**
  * Tiling test for SwiGluQuant.
  *
  * Inputs:  x [256, 512] BF16, smooth scales [16, 256] FP32, offsets [16] FP32,
  *          group index [16] INT32.
  * Outputs: y [256, 256] INT8, scale [256] FP32.
  * Attrs:   activate_left=false, quant_mode="dynamic", group_list_type=0, dst_type=2.
  *
  * Runs the registered tiling-parse function against a 40-core platform description,
  * then the tiling function, and checks that tiling key 206 is selected.
  */
 TEST_F(SwiGluQuantTiling, swi_glu_quant_tiling_002) {
   gert::StorageShape input_shape = {{256, 512}, {256, 512}};
   gert::StorageShape smooth_scales_shape = {{16, 256}, {16, 256}};
   gert::StorageShape offset_shape = {{16}, {16}};
   gert::StorageShape groupIndex_shape = {{16}, {16}};
   gert::StorageShape out_shape = {{256, 256}, {256, 256}};
   gert::StorageShape scale_shape = {{256}, {256}};

   // Hardware description: used to fill the platform resource maps below and also
   // handed to the tiling-parse function as its first input.
   string compile_info_string = R"({
          "hardware_info": {"BT_SIZE": 0, "load3d_constraints": "1",
                            "Intrinsic_fix_pipe_l0c2out": false, "Intrinsic_data_move_l12ub": true, "Intrinsic_data_move_l0c2ub": true, "Intrinsic_data_move_out2l1_nd2nz": false,
                            "UB_SIZE": 196608, "L2_SIZE": 33554432, "L1_SIZE": 524288,
                            "L0A_SIZE": 65536, "L0B_SIZE": 65536, "L0C_SIZE": 131072,
                            "CORE_NUM": 40}
                            })";
   map<string, string> soc_infos;
   map<string, string> aicore_spec;
   map<string, string> intrinsics;
   GetPlatFormInfos(compile_info_string.c_str(), soc_infos, aicore_spec, intrinsics);
   std::cout << "GetPlatFormInfos" << soc_infos.size() << " " << aicore_spec.size() << " " << intrinsics.size()
             << std::endl;

   // platform info
   fe::PlatFormInfos platform_info;
   platform_info.Init();
   // compile info (filled in by the tiling-parse function)
   optiling::SwiGluQuantCompileInfo compile_info;

   // Look the op up once, and fail the test (instead of crashing) if either entry point is missing.
   std::string op_type("SwiGluQuant");
   auto op_impl = gert::OpImplRegistry::GetInstance().GetOpImpl(op_type.c_str());
   ASSERT_NE(op_impl, nullptr);
   auto tiling_func = op_impl->tiling;
   auto tiling_parse_func = op_impl->tiling_parse;
   ASSERT_NE(tiling_func, nullptr);
   ASSERT_NE(tiling_parse_func, nullptr);

   // tilingParseFunc simulate: inputs are (compile info json, platform info), output is the compile info.
   auto kernel_holder =
       gert::KernelRunContextFaker()
           .KernelIONum(2, 1)
           .Inputs({const_cast<char*>(compile_info_string.c_str()), reinterpret_cast<void*>(&platform_info)})
           .Outputs({&compile_info})
           .Build();

   auto parse_platform_info = kernel_holder.GetContext<gert::TilingParseContext>()->GetPlatformInfo();
   ASSERT_NE(parse_platform_info, nullptr);
   ASSERT_TRUE(parse_platform_info->Init());
   parse_platform_info->SetPlatformRes("SoCInfo", soc_infos);
   parse_platform_info->SetPlatformRes("AICoreSpec", aicore_spec);
   parse_platform_info->SetCoreNumByCoreType("AICore");
   parse_platform_info->SetPlatformRes("AICoreintrinsicDtypeMap", intrinsics);

   ASSERT_EQ(tiling_parse_func(kernel_holder.GetContext<gert::KernelContext>()), ge::GRAPH_SUCCESS);

   // tilingFunc simulate: buffers for the tiling data and the workspace sizes it reports.
   auto param = gert::TilingData::CreateCap(4096);
   ASSERT_NE(param, nullptr);
   auto workspace_size_holder = gert::ContinuousVector::Create<size_t>(4096);
   auto ws_size = reinterpret_cast<gert::ContinuousVector*>(workspace_size_holder.get());
   ASSERT_NE(ws_size, nullptr);

   auto holder = gert::TilingContextFaker()
                     .SetOpType("SwiGluQuant")
                     .NodeIoNum(4, 2)
                     .IrInstanceNum({1, 1, 1, 1})
                     .InputShapes({&input_shape, &smooth_scales_shape, &offset_shape, &groupIndex_shape})
                     .OutputShapes({&out_shape, &scale_shape})
                     .CompileInfo(&compile_info)
                     .PlatformInfo(reinterpret_cast<char*>(&platform_info))
                     .NodeInputTd(0, ge::DT_BF16, ge::FORMAT_ND, ge::FORMAT_ND)
                     .NodeInputTd(1, ge::DT_FLOAT, ge::FORMAT_ND, ge::FORMAT_ND)
                     .NodeInputTd(2, ge::DT_FLOAT, ge::FORMAT_ND, ge::FORMAT_ND)
                     .NodeInputTd(3, ge::DT_INT32, ge::FORMAT_ND, ge::FORMAT_ND)
                     .NodeOutputTd(0, ge::DT_INT8, ge::FORMAT_ND, ge::FORMAT_ND)
                     .NodeOutputTd(1, ge::DT_FLOAT, ge::FORMAT_ND, ge::FORMAT_ND)
                     .NodeAttrs({{"activate_left", Ops::NN::AnyValue::CreateFrom<bool>(false)},
                                 {"quant_mode", Ops::NN::AnyValue::CreateFrom<std::string>("dynamic")},
                                 {"group_list_type", Ops::NN::AnyValue::CreateFrom<int64_t>(0)},
                                 {"dst_type", Ops::NN::AnyValue::CreateFrom<int64_t>(2)}})
                     .TilingData(param.get())
                     .Workspace(ws_size)
                     .Build();

   gert::TilingContext* tiling_context = holder.GetContext<gert::TilingContext>();
   ASSERT_NE(tiling_context, nullptr);
   auto tiling_platform_info = tiling_context->GetPlatformInfo();
   ASSERT_NE(tiling_platform_info, nullptr);
   tiling_platform_info->SetPlatformRes("SoCInfo", soc_infos);
   tiling_platform_info->SetPlatformRes("AICoreSpec", aicore_spec);
   tiling_platform_info->SetCoreNumByCoreType("AICore");
   tiling_platform_info->SetPlatformRes("AICoreintrinsicDtypeMap", intrinsics);

   // Tiling must succeed; otherwise the tiling key checked below is meaningless, so stop here.
   ASSERT_EQ(tiling_func(tiling_context), ge::GRAPH_SUCCESS);

   // Expected tiling key for this BF16 / dynamic-quant configuration
   // (NOTE(review): value taken from the op's tiling implementation — keep in sync).
   auto tiling_key = tiling_context->GetTilingKey();
   ASSERT_EQ(tiling_key, 206);
 }